{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain torch Conv1d: 1 input channel, 256 output channels, kernel width 5.\n",
    "convlayer = torch.nn.Conv1d(in_channels=1, out_channels=256, kernel_size=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random integer-valued sequence (values 1..9) of shape (1, 20), stored as float32.\n",
    "x = torch.randint(low=1, high=10, size=(1, 20)).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 6., 1., 3., 6., 9., 4., 7., 3., 9., 7., 5., 3., 9., 3., 6., 9., 5.,\n",
      "         4., 7.]])\n"
     ]
    }
   ],
   "source": [
    "# Show the sampled input sequence.\n",
    "print(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 1, 20])\n"
     ]
    }
   ],
   "source": [
    "# unsqueeze(0) adds a leading axis, turning the (1, 20) tensor into (1, 1, 20).\n",
    "print(x.unsqueeze(0).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CausalConv1d(torch.nn.Conv1d):\n",
    "    \"\"\"1-D convolution whose output at step t only depends on inputs at or before t.\n",
    "\n",
    "    The layer is an ordinary ``torch.nn.Conv1d`` created with ``padding=0``.\n",
    "    In ``forward`` the last dimension of the input is zero-padded on the left\n",
    "    by ``(kernel_size - 1) * dilation`` samples before the convolution runs,\n",
    "    so with ``stride=1`` the output has the same length as the input.\n",
    "\n",
    "    Args:\n",
    "        in_channels: Number of channels of the input signal.\n",
    "        out_channels: Number of channels produced by the convolution.\n",
    "        kernel_size: Width of the kernel (int or 1-tuple).\n",
    "        stride: Stride of the convolution.\n",
    "        dilation: Spacing between kernel taps (int or 1-tuple).\n",
    "        groups: Passed through to ``torch.nn.Conv1d``.\n",
    "        bias: If True, the convolution has a learnable bias.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 in_channels,\n",
    "                 out_channels,\n",
    "                 kernel_size,\n",
    "                 stride=1,\n",
    "                 dilation=1,\n",
    "                 groups=1,\n",
    "                 bias=True):\n",
    "\n",
    "        # Zero-argument super() keeps working when this cell is re-run and the\n",
    "        # class object is rebound in the notebook namespace.\n",
    "        super().__init__(\n",
    "            in_channels,\n",
    "            out_channels,\n",
    "            kernel_size=kernel_size,\n",
    "            stride=stride,\n",
    "            padding=0,\n",
    "            dilation=dilation,\n",
    "            groups=groups,\n",
    "            bias=bias)\n",
    "\n",
    "        # Conv1d stores kernel_size and dilation as 1-tuples, so reading them\n",
    "        # back works whether the caller passed ints or tuples.\n",
    "        self.__padding = (self.kernel_size[0] - 1) * self.dilation[0]\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"Left-pad the last dimension of ``input`` and apply the convolution.\"\"\"\n",
    "        return super().forward(F.pad(input, (self.__padding, 0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters of the small causal-convolution example built in the next cell.\n",
    "in_channels, out_channels = 1, 1\n",
    "kernel_size, dilation = 3, 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebind convlayer to a causal convolution built from the hyper-parameters above.\n",
    "convlayer = CausalConv1d(\n",
    "    in_channels=in_channels,\n",
    "    out_channels=out_channels,\n",
    "    kernel_size=kernel_size,\n",
    "    dilation=dilation,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 1, 20])\n",
      "tensor([[[-0.0975, -0.2860,  2.2729, -1.9973,  0.2712,  0.5473,  1.5007,\n",
      "          -2.1990,  1.4213, -2.2102,  2.1639, -0.5824, -0.5335, -1.5114,\n",
      "           2.6773, -2.5240,  0.5473,  1.3724, -1.3607, -0.8014]]],\n",
      "       grad_fn=<SqueezeBackward1>)\n"
     ]
    }
   ],
   "source": [
    "# Run the layer once and reuse the result for both printouts.\n",
    "causal_out = convlayer(x.unsqueeze(0))\n",
    "print(causal_out.shape)\n",
    "print(causal_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproduce by hand the left padding that CausalConv1d applies in forward().\n",
    "__padding = dilation * (kernel_size - 1)\n",
    "x_pad = F.pad(x, pad=(__padding, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-0.0975, grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Recompute the first causal output by hand: the kernel taps read the padded\n",
    "# input at offsets 0, dilation, ..., (kernel_size - 1) * dilation, so the window\n",
    "# is derived from the hyper-parameters instead of a hard-coded 0:3.\n",
    "first_window = x_pad[0][:__padding + 1:dilation]\n",
    "torch.dot(convlayer.weight[0][0], first_window) + convlayer.bias[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class context_embedding(torch.nn.Module):\n",
    "    \"\"\"Embed a sequence with a causal convolution followed by a sigmoid.\n",
    "\n",
    "    Args:\n",
    "        in_channels: Number of channels of the input sequence.\n",
    "        embedding_size: Number of output channels of the causal convolution.\n",
    "        k: Kernel size of the causal convolution.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,in_channels=1,embedding_size=256,k=5):\n",
    "        # Zero-argument super() keeps working when this cell is re-run.\n",
    "        super().__init__()\n",
    "        self.causal_convolution = CausalConv1d(in_channels,embedding_size,kernel_size=k)\n",
    "\n",
    "    def forward(self,x):\n",
    "        \"\"\"Return the element-wise sigmoid of the causal convolution of ``x``.\n",
    "\n",
    "        Assumes ``x`` has the 3-D layout Conv1d expects; the notebook calls\n",
    "        this with a (1, 1, 20) tensor and gets (1, 256, 20) back.\n",
    "        \"\"\"\n",
    "        x = self.causal_convolution(x)\n",
    "        # torch.sigmoid replaces the deprecated F.sigmoid; the values are identical.\n",
    "        return torch.sigmoid(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Context embedding with 256 output channels for a 1-channel input, kernel width 5.\n",
    "embedding = context_embedding(in_channels=1, embedding_size=256, k=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 256, 20])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the embedding produced for the (1, 1, 20) input.\n",
    "embedding(x.unsqueeze(0)).shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
